from typing import Optional

import torch


class TablewiseEmbeddingBagConfig:
    """Configuration record for one table of a table-wise embedding bag.

    The constructor only stores its arguments as attributes of the same
    name; it performs no validation and allocates nothing. The example below
    shows these configs being collected into a list, one entry per table::

        def prepare_tablewise_config(args, cache_ratio, ...):
            embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
            ...
            return embedding_bag_config_list

    Args:
        num_embeddings: Number of rows in this table (presumably the feature's
            vocabulary size -- confirm against the consumer of this config).
        cuda_row_num: Number of rows associated with the CUDA device;
            presumably the capacity of a GPU-side row cache -- TODO confirm.
        assigned_rank: Rank this table is assigned to. Defaults to ``0``.
        buffer_size: Buffer size passed through to the consumer.
            Defaults to ``50_000``.
        ids_freq_mapping: Optional id-frequency information. Its expected
            type and format are defined by the consumer and are not checked
            here.
        initial_weight: Optional tensor holding initial weights for the table.
            Presumably ``None`` lets the consumer choose its own
            initialization -- verify against the consumer.
        name: Human-readable table name. Defaults to ``""``.
    """

    def __init__(self,
                 num_embeddings: int,
                 cuda_row_num: int,
                 assigned_rank: int = 0,
                 buffer_size: int = 50_000,
                 ids_freq_mapping=None,
                 # `torch.Tensor` is the tensor class; `torch.tensor` (the old
                 # annotation) is a factory function and not a valid type.
                 initial_weight: Optional[torch.Tensor] = None,
                 name: str = ""):
        self.num_embeddings = num_embeddings
        self.cuda_row_num = cuda_row_num
        self.assigned_rank = assigned_rank
        self.buffer_size = buffer_size
        self.ids_freq_mapping = ids_freq_mapping
        self.initial_weight = initial_weight
        self.name = name

    def __repr__(self) -> str:
        # Only scalar fields are shown: `initial_weight` and
        # `ids_freq_mapping` may be large and would flood logs.
        return (f"{type(self).__name__}("
                f"name={self.name!r}, "
                f"num_embeddings={self.num_embeddings!r}, "
                f"cuda_row_num={self.cuda_row_num!r}, "
                f"assigned_rank={self.assigned_rank!r}, "
                f"buffer_size={self.buffer_size!r})")
